## Loading required package: lpSolve
## 
## Attaching package: 'salso'
## The following object is masked from 'package:mcclust':
## 
##     binder
## 
## Attaching package: 'ggpubr'
## The following object is masked from 'package:WASABI':
## 
##     ggscatter
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union

Load MAPseq data

The data is from Chen et al (2019) describing single-neuron axon projection counts to 11 brain areas: the orbitofrontal cortex (OFC), motor cortex (Motor), rostral striatum (Rstr), somatosensory cortex (SSctx), caudal striatum (Cstr), amygdala (Amyg), ipsilateral visual cortex (VisIp), contralateral visual cortex (VisC), contralateral auditory cortex (AudC), thalamus (Thal), and tectum (Tect). Data are collected across three brains (mice). For illustration, we focus on the first and third brain which are extracted using the same technology (BARseq).

# Load the projection-count data set shipped with the package.
data("data_barseq")

# Keep only the first and third brains (the two BARseq brains used in this
# illustration).
data_barseq <- list(data_barseq[[1]], data_barseq[[3]])

# M: number of brains kept; R: number of rows (brain regions) per count matrix.
M <- length(data_barseq)
R <- nrow(data_barseq[[1]])
regions.name <- rownames(data_barseq[[1]])

# C[m]: number of columns (neurons) in brain m. vapply() guarantees an integer
# vector whatever the length of the list.
C <- vapply(data_barseq, ncol, integer(1))

# Brain label of every neuron, in the order the brains are stacked. Built from
# M and C so it stays correct if the number of selected brains changes
# (the original hard-coded exactly two brains). Kept as double, as before.
mouse.index <- as.numeric(rep(seq_len(M), times = C))

Let’s visualize a heatmap of the data. Data are normalized by the total counts to reflect projection strength.

Run the MCMC algorithm

HBMAP employs a hierarchical mixture of Dirichlet-Multinomials to model axon projection data.

# Truncation level, passed as J to HBMAP_mcmc() below
J <- 35

# ---- arguments handed to the main function ------
# MCMC settings
mcmc_list <- list(
  number_iter = 20000,
  thinning = 5,
  burn_in = 10000,
  adaptive_prop = 0.0001,
  auto_save = FALSE,
  save_path = NULL,
  save_frequency = 1000
)
# Prior hyperparameters; defaults are used for any value that is not supplied
prior_list <- list(
  a_gamma = 20, b_gamma = 1, lb_gamma = 1,
  a = 1, tau = 0.2, nu = 1 / 1000,
  a_alpha = 10, b_alpha = 1,
  a_alpha0 = 5, b_alpha0 = 1
)

# Initialization
# Fix the RNG seed so the starting allocation is reproducible, then build the
# initial allocation Z.init that is handed to every chain in HBMAP_mcmc() below.
# NOTE(review): presumably k-means with k = 30 clusters on cosine-transformed
# data, 50 restarts, at most 100 iterations each -- confirm against the
# k_means_axon() documentation.
set.seed(43)
Z.init <- k_means_axon(Y = data_barseq, k = 30, transformation = 'cosine', restart = 50, iter.max = 100)

# ------- Run the full model ---------
# One RNG seed per chain; the number of chains (and of worker processes)
# follows the length of this vector.
seeds <- c(101, 112, 323, 141, 555)
mcmc_all_barseq <- parallel::mclapply(
  seq_along(seeds),
  function(g) {
    # Seed chain g from `seeds`. The original called set.seed(g), which left
    # `seeds` unused; the console output recorded in this document was
    # produced with seeds 1..5, so numbers will differ slightly on re-run.
    set.seed(seeds[g])
    HBMAP_mcmc(Y = data_barseq, J = J, mcmc = mcmc_list,
               prior = prior_list, Z.init = Z.init, verbose = TRUE)
  },
  mc.cores = length(seeds)
)

# For each chain, flatten the saved allocations (Z_output) into a matrix with
# one row per saved draw and one column per neuron (sum(C) columns in total).
cls.draw.list <- lapply(mcmc_all_barseq, function(chain) {
  d <- chain$Z_output
  matrix(unlist(d), length(d), sum(C), byrow = TRUE)
})
# Pool the draws of all chains by stacking the rows.
cls.draw <- do.call(rbind, cls.draw.list)

Clustering

Optimal clustering estimates

Let’s start by considering the optimal clustering solution obtained with different loss functions as well as the marginal posterior on the number of clusters.

# Relabel a vector of cluster labels to consecutive integers 1..K, numbered
# in order of first appearance.
#
# c: vector of cluster labels for one partition.
# Returns an integer vector of the same length as `c`; two entries share a
# value exactly when they shared a label in `c`.
relabel <- function(c) {
  # match() against the distinct labels gives each label's rank of first
  # appearance in one vectorised pass, replacing the 1:length(uu) loop.
  match(c, unique(c))
}

# Relabel every draw (row); apply() returns the results as columns, hence t()
cls.draw <- t(apply(cls.draw, 1, relabel))
# Total number of posterior draws
S <- nrow(cls.draw)

# Number of clusters in each draw: after relabelling, the largest label
K.draw <- apply(cls.draw, 1, max)

# Posterior similarity matrix of the pooled draws
psm <- mcclust::comp.psm(cls.draw)

# Point estimates of the clustering under three loss functions

# VI (no `loss` argument is passed for this one)
output_salso <- salso(x = cls.draw, maxZealousAttempts = 50)

# Binder's loss
output_salso_binder <- salso(
  x = cls.draw,
  loss = "binder",
  maxNClusters = 100,
  maxZealousAttempts = 50
)

# ARI
output_salso_ari <- salso(
  x = cls.draw,
  loss = "omARI",
  maxNClusters = 100,
  maxZealousAttempts = 50
)

# Bar chart of the marginal posterior on the number of clusters
ggplot() +
  geom_bar(aes(x = K.draw)) +
  theme_bw() +
  labs(x = "number of clusters")

print(paste('Mean of marginal posterior on the number of clusters:', mean(K.draw)))
## [1] "Mean of marginal posterior on the number of clusters: 26.8953523238381"
# VI estimate
print(paste('Number of clusters in vi estimate:', length(unique(output_salso))))
## [1] "Number of clusters in vi estimate: 27"
# Posterior similarity matrix with rows/columns grouped by the VI estimate
superheat(
  psm,
  pretty.order.rows = TRUE,
  pretty.order.cols = TRUE,
  heat.pal = c("white", "yellow", "red"),
  heat.pal.values = c(0, .5, 1),
  membership.rows = output_salso,
  membership.cols = output_salso,
  bottom.label.text.size = 4,
  left.label.text.size = 4
)

# Heatmap of the row-normalized data; the estimate is split into one
# allocation vector per brain
vi_list <- lapply(seq_len(M), function(m) output_salso[mouse.index == m])
heatmap_ps(
  Y = data_barseq,
  Z = vi_list,
  regions.name = rownames(data_barseq[[1]]),
  group.index = mouse.index,
  group.name = 'brain',
  cluster.index = seq_along(unique(output_salso)),
  title = ''
)

# Binder's loss estimate
print(paste('Number of clusters in binders estimate:', length(unique(output_salso_binder))))
## [1] "Number of clusters in binders estimate: 50"
# Posterior similarity matrix with rows/columns grouped by the Binder estimate
superheat(
  psm,
  pretty.order.rows = TRUE,
  pretty.order.cols = TRUE,
  heat.pal = c("white", "yellow", "red"),
  heat.pal.values = c(0, .5, 1),
  membership.rows = output_salso_binder,
  membership.cols = output_salso_binder,
  bottom.label.text.size = 4,
  left.label.text.size = 4
)

# Heatmap of the row-normalized data; the estimate is split into one
# allocation vector per brain
binder_list <- lapply(seq_len(M), function(m) output_salso_binder[mouse.index == m])
heatmap_ps(
  Y = data_barseq,
  Z = binder_list,
  regions.name = rownames(data_barseq[[1]]),
  group.index = mouse.index,
  group.name = 'brain',
  cluster.index = seq_along(unique(output_salso_binder)),
  title = ''
)

# ARI estimate
print(paste('Number of clusters in ari estimate:', length(unique(output_salso_ari))))
## [1] "Number of clusters in ari estimate: 37"
# Posterior similarity matrix with rows/columns grouped by the ARI estimate
superheat(
  psm,
  pretty.order.rows = TRUE,
  pretty.order.cols = TRUE,
  heat.pal = c("white", "yellow", "red"),
  heat.pal.values = c(0, .5, 1),
  membership.rows = output_salso_ari,
  membership.cols = output_salso_ari,
  bottom.label.text.size = 4,
  left.label.text.size = 4
)

# Heatmap of the row-normalized data; the estimate is split into one
# allocation vector per brain
ari_list <- lapply(seq_len(M), function(m) output_salso_ari[mouse.index == m])
heatmap_ps(
  Y = data_barseq,
  Z = ari_list,
  regions.name = rownames(data_barseq[[1]]),
  group.index = mouse.index,
  group.name = 'brain',
  cluster.index = seq_along(unique(output_salso_ari)),
  title = ''
)

Different results are obtained in this case. Binder and ARI lead to a large number of clusters, with many small clusters. VI is more parsimonious. Let’s summarize with WASABI to better understand if there are multiple modes of clustering.

WASABI

We use the elbow function to choose the number of particles \(L\) with the elbow method:

# Fit WASABI for every number of particles from 1 to L_max, timing the run
set.seed(123)
L_max <- 10
tic()
out_elbow <- elbow(
  cls.draw,
  L_max = L_max,
  multi.start = 4,
  method.init = "++",
  mini.batch = 500,
  max.iter = 20,
  extra.iter = 4,
  method = "salso"
)
## Completed  1 / 10 
## Completed  2 / 10 
## Completed  3 / 10 
## Completed  4 / 10 
## Completed  5 / 10 
## Completed  6 / 10 
## Completed  7 / 10 
## Completed  8 / 10 
## Completed  9 / 10 
## Completed  10 / 10
toc()
## 3042.824 sec elapsed
# Number of particles retained; highlighted with a red circle on the elbow plot
L <- 3
ggplot() +
  geom_point(aes(x = seq_len(L_max), y = out_elbow$wass_vec)) +
  geom_line(aes(x = seq_len(L_max), y = out_elbow$wass_vec)) +
  labs(x = "Number of particles", y = "Wasserstein distance") +
  annotate(
    "point",
    x = L,
    y = out_elbow$wass_vec[L],
    color = "red",
    shape = 1,
    size = 3
  ) +
  theme_bw()

Once the value of \(L\) is chosen, we can run another set of initializations to see if we can find a better approximation:

# Re-run WASABI with L particles from 50 starting points, timing the run
tic()
output_WASABI_mb <- WASABI_multistart(
  cls.draw, psm,
  multi.start = 50,
  ncores = 5,
  method.init = "++",
  add_topvi = FALSE,
  method = "salso",
  L = L,
  mini.batch = 500,
  max.iter = 20,
  extra.iter = 10,
  swap_countone = TRUE,
  maxNClusters = 45,
  maxZealousAttempts = 20,
  seed = 54321
)
toc()
## 9156.843 sec elapsed
# Start from the elbow fit with L particles and switch to the multi-start fit
# only when its Wasserstein distance is strictly smaller
output_WASABI <- out_elbow$output_list[[L]]
if (output_WASABI_mb$wass.dist < output_WASABI$wass.dist) {
  output_WASABI <- output_WASABI_mb
  print(paste(
    'Improved approximation with multiple initialization: Wass dist =',
    output_WASABI$wass.dist
  ))
}
## [1] "Improved approximation with multiple initialization: Wass dist = 0.937498746173342"
# Single WASABI run using the "average" initialization, with progress printed
tic()
output_WASABI_avg <- WASABI(
  cls.draw, psm,
  method.init = "average",
  method = "salso",
  L = L,
  mini.batch = 500,
  maxNClusters = 45,
  maxZealousAttempts = 20,
  max.iter = 20,
  extra.iter = 10,
  swap_countone = TRUE,
  suppress.comment = FALSE
)
## Initial particle 1 : number of clusters = 24 , EVI = 0.987 
##  Initial particle 2 : number of clusters = 23 , EVI = 0.994 
##  Initial particle 3 : number of clusters = 22 , EVI = 0.997 
## Iteration = 1 
## Particle 1 : number of clusters=27 , EVI = 0.972 , sumVI = 0.745 , w= 0.766 
##  Particle 2 : number of clusters=29 , EVI = 0.992 , sumVI = 0.212 , w= 0.214 
##  Particle 3 : number of clusters=26 , EVI = 0.899 , sumVI = 0.018 , w= 0.02 
## Wasserstein dist = 0.975023835414756 
## Iteration = 2 
## Particle 1 : number of clusters=27 , EVI = 0.959 , sumVI = 0.593 , w= 0.618 
##  Particle 2 : number of clusters=28 , EVI = 0.994 , sumVI = 0.316 , w= 0.318 
##  Particle 3 : number of clusters=25 , EVI = 0.907 , sumVI = 0.058 , w= 0.064 
## Wasserstein dist = 0.96659155909822 
## Iteration = 3 
## Particle 1 : number of clusters=27 , EVI = 0.944 , sumVI = 0.391 , w= 0.414 
##  Particle 2 : number of clusters=28 , EVI = 0.995 , sumVI = 0.422 , w= 0.424 
##  Particle 3 : number of clusters=25 , EVI = 0.898 , sumVI = 0.146 , w= 0.162 
## Wasserstein dist = 0.958539135616072 
## Iteration = 4 
## Particle 1 : number of clusters=27 , EVI = 0.953 , sumVI = 0.434 , w= 0.456 
##  Particle 2 : number of clusters=28 , EVI = 1.002 , sumVI = 0.379 , w= 0.378 
##  Particle 3 : number of clusters=25 , EVI = 0.919 , sumVI = 0.153 , w= 0.166 
## Wasserstein dist = 0.965515185543286 
## Iteration = 5 
## Particle 1 : number of clusters=27 , EVI = 0.942 , sumVI = 0.463 , w= 0.492 
##  Particle 2 : number of clusters=27 , EVI = 1.001 , sumVI = 0.354 , w= 0.354 
##  Particle 3 : number of clusters=25 , EVI = 0.92 , sumVI = 0.142 , w= 0.154 
## Wasserstein dist = 0.959562766349092 
## Iteration = 6 
## Particle 1 : number of clusters=29 , EVI = 0.947 , sumVI = 0.494 , w= 0.522 
##  Particle 2 : number of clusters=26 , EVI = 0.999 , sumVI = 0.328 , w= 0.328 
##  Particle 3 : number of clusters=26 , EVI = 0.909 , sumVI = 0.136 , w= 0.15 
## Wasserstein dist = 0.958268319961711 
## Iteration = 7 
## Particle 1 : number of clusters=29 , EVI = 0.963 , sumVI = 0.481 , w= 0.5 
##  Particle 2 : number of clusters=25 , EVI = 1.007 , sumVI = 0.304 , w= 0.302 
##  Particle 3 : number of clusters=26 , EVI = 0.92 , sumVI = 0.182 , w= 0.198 
## Wasserstein dist = 0.967676608111123 
## Iteration = 8 
## Particle 1 : number of clusters=28 , EVI = 0.962 , sumVI = 0.523 , w= 0.544 
##  Particle 2 : number of clusters=25 , EVI = 0.997 , sumVI = 0.255 , w= 0.256 
##  Particle 3 : number of clusters=26 , EVI = 0.932 , sumVI = 0.186 , w= 0.2 
## Wasserstein dist = 0.964867984158873 
## Iteration = 9 
## Particle 1 : number of clusters=27 , EVI = 0.953 , sumVI = 0.481 , w= 0.504 
##  Particle 2 : number of clusters=27 , EVI = 0.989 , sumVI = 0.289 , w= 0.292 
##  Particle 3 : number of clusters=25 , EVI = 0.923 , sumVI = 0.188 , w= 0.204 
## Wasserstein dist = 0.957460908220007 
## Iteration = 10 
## Particle 1 : number of clusters=27 , EVI = 0.964 , sumVI = 0.482 , w= 0.5 
##  Particle 2 : number of clusters=26 , EVI = 0.987 , sumVI = 0.286 , w= 0.29 
##  Particle 3 : number of clusters=26 , EVI = 0.932 , sumVI = 0.196 , w= 0.21 
## Wasserstein dist = 0.963786856334065 
## Iteration = 11 
## Particle 1 : number of clusters=27 , EVI = 0.95 , sumVI = 0.479 , w= 0.504 
##  Particle 2 : number of clusters=26 , EVI = 0.995 , sumVI = 0.294 , w= 0.296 
##  Particle 3 : number of clusters=26 , EVI = 0.904 , sumVI = 0.181 , w= 0.2 
## Wasserstein dist = 0.954280828876736 
## Iteration = 12 
## Particle 1 : number of clusters=28 , EVI = 0.953 , sumVI = 0.545 , w= 0.572 
##  Particle 2 : number of clusters=26 , EVI = 0.987 , sumVI = 0.259 , w= 0.262 
##  Particle 3 : number of clusters=27 , EVI = 0.912 , sumVI = 0.151 , w= 0.166 
## Wasserstein dist = 0.955295323499047 
## *Running full batch after mini-batch*
## Iteration = 13 
## Particle 1 : number of clusters=28 , EVI = 0.953 , sumVI = 0.498 , w= 0.523 
##  Particle 2 : number of clusters=26 , EVI = 0.991 , sumVI = 0.275 , w= 0.277 
##  Particle 3 : number of clusters=26 , EVI = 0.918 , sumVI = 0.183 , w= 0.199 
## Wasserstein dist = 0.956424856071233 
## Iteration = 14 
## Particle 1 : number of clusters=28 , EVI = 0.951 , sumVI = 0.493 , w= 0.518 
##  Particle 2 : number of clusters=26 , EVI = 0.992 , sumVI = 0.278 , w= 0.28 
##  Particle 3 : number of clusters=26 , EVI = 0.918 , sumVI = 0.186 , w= 0.202 
## Wasserstein dist = 0.95604751395645
# Report the Wasserstein distance stored in the "average"-initialization fit
print(paste('Average initialization: Wass dist =',output_WASABI_avg$wass.dist))
## [1] "Average initialization: Wass dist = 0.95604751395645"
# Close the timer opened by the tic() before the WASABI() call
toc()
## 165.694 sec elapsed
# Single WASABI run using the "complete" initialization, with progress printed
tic()
output_WASABI_comp <- WASABI(
  cls.draw, psm,
  method.init = "complete",
  method = "salso",
  L = L,
  mini.batch = 500,
  maxNClusters = 45,
  maxZealousAttempts = 20,
  max.iter = 20,
  extra.iter = 10,
  swap_countone = TRUE,
  suppress.comment = FALSE
)
## Initial particle 1 : number of clusters = 23 , EVI = 1.032 
##  Initial particle 2 : number of clusters = 22 , EVI = 1.04 
##  Initial particle 3 : number of clusters = 26 , EVI = 1.045 
## Iteration = 1 
## Particle 1 : number of clusters=28 , EVI = 0.947 , sumVI = 0.428 , w= 0.452 
##  Particle 2 : number of clusters=28 , EVI = 0.956 , sumVI = 0.163 , w= 0.17 
##  Particle 3 : number of clusters=29 , EVI = 0.984 , sumVI = 0.372 , w= 0.378 
## Wasserstein dist = 0.962473273883556 
## Iteration = 2 
## Particle 1 : number of clusters=27 , EVI = 0.949 , sumVI = 0.44 , w= 0.464 
##  Particle 2 : number of clusters=27 , EVI = 0.954 , sumVI = 0.193 , w= 0.202 
##  Particle 3 : number of clusters=30 , EVI = 0.969 , sumVI = 0.324 , w= 0.334 
## Wasserstein dist = 0.956685416838028 
## Iteration = 3 
## Particle 1 : number of clusters=27 , EVI = 0.942 , sumVI = 0.563 , w= 0.598 
##  Particle 2 : number of clusters=26 , EVI = 0.951 , sumVI = 0.192 , w= 0.202 
##  Particle 3 : number of clusters=30 , EVI = 0.932 , sumVI = 0.186 , w= 0.2 
## Wasserstein dist = 0.941559358077484 
## Iteration = 4 
## Particle 1 : number of clusters=26 , EVI = 0.939 , sumVI = 0.533 , w= 0.568 
##  Particle 2 : number of clusters=24 , EVI = 0.967 , sumVI = 0.234 , w= 0.242 
##  Particle 3 : number of clusters=31 , EVI = 0.922 , sumVI = 0.175 , w= 0.19 
## Wasserstein dist = 0.942299756511466 
## *Running full batch after mini-batch*
## Iteration = 5 
## Particle 1 : number of clusters=27 , EVI = 0.942 , sumVI = 0.527 , w= 0.559 
##  Particle 2 : number of clusters=25 , EVI = 0.967 , sumVI = 0.241 , w= 0.249 
##  Particle 3 : number of clusters=31 , EVI = 0.921 , sumVI = 0.177 , w= 0.192 
## Wasserstein dist = 0.944386961370636 
## Iteration = 6 
## Particle 1 : number of clusters=27 , EVI = 0.943 , sumVI = 0.547 , w= 0.581 
##  Particle 2 : number of clusters=25 , EVI = 0.97 , sumVI = 0.226 , w= 0.233 
##  Particle 3 : number of clusters=31 , EVI = 0.912 , sumVI = 0.17 , w= 0.186 
## Wasserstein dist = 0.943234195737468 
## Iteration = 7 
## Particle 1 : number of clusters=27 , EVI = 0.943 , sumVI = 0.559 , w= 0.593 
##  Particle 2 : number of clusters=26 , EVI = 0.97 , sumVI = 0.217 , w= 0.223 
##  Particle 3 : number of clusters=31 , EVI = 0.909 , sumVI = 0.167 , w= 0.183 
## Wasserstein dist = 0.942959596749585
# Report the Wasserstein distance stored in the "complete"-initialization fit
print(paste('Complete initialization: Wass dist =',output_WASABI_comp$wass.dist))
## [1] "Complete initialization: Wass dist = 0.942959596749585"
# Close the timer opened by the tic() before the WASABI() call
toc()
## 240.105 sec elapsed
# User-supplied starting particles: one Binder estimate per particle, each
# computed with a different cap on the number of clusters (nclus[l]).
# Row l of part.init holds the allocation of all sum(C) neurons.
part.init <- matrix(0, L, sum(C))
nclus <- c(25, 28, 32)
for (l in seq_len(L)) {
  part.init[l, ] <- as.numeric(
    salso(x = cls.draw, loss = "binder", maxNClusters = nclus[l])
  )
}

# WASABI run started from those fixed particles, with progress printed
tic()
output_WASABI_fxd <- WASABI(
  cls.draw, psm,
  method.init = "fixed",
  part.init = part.init,
  method = "salso",
  L = L,
  maxNClusters = 45,
  maxZealousAttempts = 20,
  max.iter = 30,
  swap_countone = TRUE,
  suppress.comment = FALSE
)
## Initial particle 1 : number of clusters = 25 , EVI = 1.082 
##  Initial particle 2 : number of clusters = 28 , EVI = 1.067 
##  Initial particle 3 : number of clusters = 32 , EVI = 1.062 
## Iteration = 1 
## Particle 1 : number of clusters=26 , EVI = 0.911 , sumVI = 0.003 , w= 0.004 
##  Particle 2 : number of clusters=27 , EVI = 0.967 , sumVI = 0.307 , w= 0.318 
##  Particle 3 : number of clusters=28 , EVI = 0.981 , sumVI = 0.666 , w= 0.678 
## Wasserstein dist = 0.976621847585401 
## Iteration = 2 
## Particle 1 : number of clusters=25 , EVI = 0.982 , sumVI = 0.197 , w= 0.201 
##  Particle 2 : number of clusters=27 , EVI = 0.974 , sumVI = 0.382 , w= 0.392 
##  Particle 3 : number of clusters=28 , EVI = 0.958 , sumVI = 0.389 , w= 0.406 
## Wasserstein dist = 0.969011163807456 
## Iteration = 3 
## Particle 1 : number of clusters=25 , EVI = 0.999 , sumVI = 0.282 , w= 0.282 
##  Particle 2 : number of clusters=27 , EVI = 0.962 , sumVI = 0.346 , w= 0.36 
##  Particle 3 : number of clusters=28 , EVI = 0.948 , sumVI = 0.339 , w= 0.358 
## Wasserstein dist = 0.967320157940298 
## Iteration = 4 
## Particle 1 : number of clusters=26 , EVI = 1.005 , sumVI = 0.279 , w= 0.277 
##  Particle 2 : number of clusters=27 , EVI = 0.959 , sumVI = 0.349 , w= 0.364 
##  Particle 3 : number of clusters=28 , EVI = 0.944 , sumVI = 0.339 , w= 0.359 
## Wasserstein dist = 0.966111948393442 
## Iteration = 5 
## Particle 1 : number of clusters=26 , EVI = 1.005 , sumVI = 0.286 , w= 0.285 
##  Particle 2 : number of clusters=27 , EVI = 0.959 , sumVI = 0.36 , w= 0.376 
##  Particle 3 : number of clusters=28 , EVI = 0.94 , sumVI = 0.319 , w= 0.339 
## Wasserstein dist = 0.965703954987267
# Report the Wasserstein distance stored in the fixed-initialization fit
print(paste('Fixed initialization: Wass dist =',output_WASABI_fxd$wass.dist))
## [1] "Fixed initialization: Wass dist = 0.965703954987267"
# Close the timer opened by the tic() before the WASABI() call
toc()
## 408.364 sec elapsed

WASABI visualizations

WASABI provides a number of visualization tools. Let’s first consider the number and weights of the particles.

# Summary plot of the retained WASABI approximation.
# NOTE(review): presumably displays the particles and their weights -- confirm
# against the ggsummary() documentation.
ggsummary(output_WASABI)

We can also plot the data colored by cluster membership for each particle.

# Build a (neurons x 2) matrix of normalized data restricted to two regions,
# used as the coordinates of the scatter plots.
# NOTE(review): phat and phat2 are not defined in the code shown in this
# document -- presumably the normalized count matrices of the two brains from
# the heatmap step; confirm.
data_norm <- list(phat, phat2)
r1 <- 5
r2 <- 9
# Rows r1 and r2 of every brain's matrix
data_norm_r1r2_list <- lapply(data_norm, function(d) d[c(r1, r2), ])
# unlist() walks each 2 x C[m] matrix neuron by neuron, so filling by row with
# two columns gives one row per neuron: column 1 = region r1, column 2 = r2
data_norm_r1r2 <- matrix(
  unlist(data_norm_r1r2_list),
  sum(C), 2,
  byrow = TRUE
)
ggscatter_grid2d(output_WASABI, data_norm_r1r2) +
  labs(x = regions.name[r1], y = regions.name[r2])

To better investigate the differences between any two particles, we can also look at the VI contribution of each point (e.g. particle 1 and particle 2):

# Pair of particles to compare
p1 <- 1
p2 <- 2
# Per-neuron VI contribution between the two particles, and their meet
VIC_p1p2 <- vi.contribution(
  output_WASABI$particles[p1, ],
  output_WASABI$particles[p2, ]
)
meet_p1p2 <- cls.meet(output_WASABI$particles[c(p1, p2), ])

# First four colours of the 5-colour "Purple-Yellow" palette, reversed
colors <- rev(sequential_hcl(5, palette = "Purple-Yellow")[1:4])
# Scatter of the two selected regions: colour = VI contribution,
# point shape = meet cluster
ggplot() +
  geom_point(aes(
    x = data_norm_r1r2[, 1],
    y = data_norm_r1r2[, 2],
    color = VIC_p1p2,
    shape = as.factor(meet_p1p2$cls.m)
  )) +
  theme_bw() +
  #scale_color_distiller(name = "VIC",palette = "OrRd",direction = 1)+
  scale_color_gradientn(
    colours = colors,
    transform = "sqrt",
    labels = function(x) sprintf("%.4f", x)
  ) +
  scale_shape_manual(values = seq_along(unique(meet_p1p2$cls.m))) +
  guides(shape = guide_legend(title = "Meet\ncluster")) +
  labs(x = regions.name[r1], y = regions.name[r2]) +
  ggtitle("VI Contribution between particle 1 and 2")

Alternative useful visualizations of MAPseq data are provided by gel plots (heat maps of the normalized data) and line plots (line plots of the normalized data). In the line plots, we filter to clusters/motifs containing more than 10 neurons, as we are not interested in projection patterns characterized by only a small group of neurons.

# Heat maps of the row-normalized data for each particle, coloured by the
# per-neuron VI contribution relative to particle 1
VIC_p1p1 <- vi.contribution(output_WASABI$particles[1, ], output_WASABI$particles[1, ])
VIC_p1p2 <- vi.contribution(output_WASABI$particles[1, ], output_WASABI$particles[2, ])
VIC_p1p3 <- vi.contribution(output_WASABI$particles[1, ], output_WASABI$particles[3, ])
# Shared colour limits so the three panels are comparable
lmts <- c(0, max(VIC_p1p1, VIC_p1p2, VIC_p1p3))

# Particle 1 (each particle is split into one allocation vector per brain)
p1_list <- lapply(seq_len(M), function(m) output_WASABI$particles[1, mouse.index == m])
ps_p1 <- heatmap_VIC(
  Y = data_barseq,
  Z = p1_list,
  regions.name = rownames(data_barseq[[1]]),
  vic = VIC_p1p1,
  cluster.index = seq_along(unique(output_WASABI$particles[1, ])),
  title = paste('Particle 1', round(output_WASABI$part.weights[1], 3)),
  limts = lmts
)

# Particle 2
p2_list <- lapply(seq_len(M), function(m) output_WASABI$particles[2, mouse.index == m])
ps_p2 <- heatmap_VIC(
  Y = data_barseq,
  Z = p2_list,
  regions.name = rownames(data_barseq[[1]]),
  vic = VIC_p1p2,
  cluster.index = seq_along(unique(output_WASABI$particles[2, ])),
  title = paste('Particle 2', round(output_WASABI$part.weights[2], 3)),
  limts = lmts
)

# Particle 3
p3_list <- lapply(seq_len(M), function(m) output_WASABI$particles[3, mouse.index == m])
ps_p3 <- heatmap_VIC(
  Y = data_barseq,
  Z = p3_list,
  regions.name = rownames(data_barseq[[1]]),
  vic = VIC_p1p3,
  cluster.index = seq_along(unique(output_WASABI$particles[3, ])),
  title = paste('Particle 3', round(output_WASABI$part.weights[3], 3)),
  limts = lmts
)

# Three panels side by side with a single legend
ggarrange(ps_p1, ps_p2, ps_p3, ncol = 3, nrow = 1, common.legend = TRUE, legend = "right")

# Heat maps of the row-normalized data, coloured by the group VIC: for every
# neuron, the sum of the VI contributions of all neurons in its meet cluster.
meet_p1p2 <- cls.meet(output_WASABI$particles[c(1, 2), ])$cls.m
VIC_p1p2_group <- sapply(1:sum(C), function(c) {
  sum(VIC_p1p2[meet_p1p2 == meet_p1p2[c]])
})
meet_p1p3 <- cls.meet(output_WASABI$particles[c(1, 3), ])$cls.m
VIC_p1p3_group <- sapply(1:sum(C), function(c) {
  sum(VIC_p1p3[meet_p1p3 == meet_p1p3[c]])
})
# Colour limits must span the values actually plotted below. The original took
# the maximum of the per-neuron VIC_p1p2 although the particle 2 panel is
# coloured by VIC_p1p2_group, so that panel could exceed the scale.
lmts <- c(0, max(VIC_p1p1, VIC_p1p2_group, VIC_p1p3_group))

# Particle 1 (reference particle, coloured by its contribution against itself)
ps_p1 <- heatmap_VIC(
  Y = data_barseq,
  Z = p1_list,
  regions.name = rownames(data_barseq[[1]]),
  vic = VIC_p1p1,
  cluster.index = 1:length(unique(output_WASABI$particles[1, ])),
  title = paste('Particle 1', round(output_WASABI$part.weights[1], 3)),
  limts = lmts
) +
  labs(fill = "VICG")

# Particle 2
p2_list <- lapply(1:M, function(m) output_WASABI$particles[2, mouse.index == m])
ps_p2 <- heatmap_VIC(
  Y = data_barseq,
  Z = p2_list,
  regions.name = rownames(data_barseq[[1]]),
  vic = VIC_p1p2_group,
  cluster.index = 1:length(unique(output_WASABI$particles[2, ])),
  title = paste('Particle 2', round(output_WASABI$part.weights[2], 3)),
  limts = lmts
) +
  labs(fill = "VICG")

# Particle 3
p3_list <- lapply(1:M, function(m) output_WASABI$particles[3, mouse.index == m])
ps_p3 <- heatmap_VIC(
  Y = data_barseq,
  Z = p3_list,
  regions.name = rownames(data_barseq[[1]]),
  vic = VIC_p1p3_group,
  cluster.index = 1:length(unique(output_WASABI$particles[3, ])),
  title = paste('Particle 3', round(output_WASABI$part.weights[3], 3)),
  limts = lmts
) +
  labs(fill = "VICG")

# Three panels side by side with a single legend
ggarrange(ps_p1, ps_p2, ps_p3, ncol = 3, nrow = 1, common.legend = TRUE, legend = "right")

# Line plots coloured by the VI contribution, to highlight where the
# particles differ. Each VIC vector is split into one piece per brain.
VIC_p1p1_list <- lapply(seq_len(M), function(m) VIC_p1p1[mouse.index == m])
VIC_p1p2_list <- lapply(seq_len(M), function(m) VIC_p1p2[mouse.index == m])
VIC_p1p3_list <- lapply(seq_len(M), function(m) VIC_p1p3[mouse.index == m])

# Brain label (as a factor) repeated once per neuron, one vector per brain
mouse.list <- lapply(seq_len(M), function(m) rep(as.factor(m), C[m]))
# Keep only clusters with more than 10 neurons in each particle
motifs_filter1 <- which(table(output_WASABI$particles[1, ]) > 10)
motifs_filter2 <- which(table(output_WASABI$particles[2, ]) > 10)
motifs_filter3 <- which(table(output_WASABI$particles[3, ]) > 10)

pl_p1_vic <- projection_vic(
  data_barseq, mouse.list, p1_list, regions.name,
  motifs = motifs_filter1, VIC_p1p1_list, limts = lmts, ncol = 7
) +
  labs(title = 'Particle 1')
pl_p2_vic <- projection_vic(
  data_barseq, mouse.list, p2_list, regions.name,
  motifs = motifs_filter2, VIC_p1p2_list, limts = lmts, ncol = 7
) +
  labs(title = 'Particle 2')
pl_p3_vic <- projection_vic(
  data_barseq, mouse.list, p3_list, regions.name,
  motifs = motifs_filter3, VIC_p1p3_list, limts = lmts, ncol = 7
) +
  labs(title = 'Particle 3')

# One particle per row, with a single legend
ggarrange(pl_p1_vic, pl_p2_vic, pl_p3_vic, ncol = 1, nrow = 3, common.legend = TRUE, legend = "right")

To label each cluster, we consider that neurons in each group project to a region if the average projection strength is greater than 0.02.

# Label the clusters of particle p by the regions they project to.
p <- 1
# All brains side by side: R rows (regions) by sum(C) columns (neurons)
data_norm_cbind <- t(matrix(unlist(data_norm), sum(C), R, byrow = TRUE))
part_motif_names <- lapply(
  sort(unique(output_WASABI$particles[p, ])),
  function(j) {
    # Neurons allocated to cluster j; drop = FALSE keeps an R x n matrix even
    # when the cluster holds a single neuron
    in_cluster <- output_WASABI$particles[p, ] == j
    data.j <- data_norm_cbind[, in_cluster, drop = FALSE]
    # Average projection strength per region within the cluster
    data.j.average <- apply(data.j, 1, mean)

    # A cluster "projects to" a region when the average strength is >= 0.02
    pp.regions <- paste(regions.name[data.j.average >= 0.02], collapse = ',')
    # Share of all neurons that belong to this cluster
    pp.weight <- sum(in_cluster) / sum(C)
    # One row per region: cluster, label and weight are recycled R times
    data.frame(cluster = j,
               pp.regions = pp.regions,
               pp.weight = pp.weight,
               pp.strength = data.j.average)
  }
)

part_motif_names <- do.call(rbind, part_motif_names)
# Every cluster contributes R consecutive rows, so print every R-th row to
# show each cluster once (the original hard-coded the stride as 11).
print(part_motif_names[seq(1, nrow(part_motif_names), by = R), c(1, 2)])
##     cluster                                pp.regions
## 1         1                                      Amyg
## 12        2 OFC,Motor,Rstr,SSctx,Cstr,Amyg,VisIp,AudC
## 23        3                                      AudC
## 34        4                            Cstr,Amyg,AudC
## 45        5       Rstr,Cstr,Amyg,VisIp,VisC,AudC,Thal
## 56        6                                 Thal,Tect
## 67        7                                      Thal
## 78        8                                Amyg,VisIp
## 89        9                                     VisIp
## 100      10                           VisIp,VisC,Tect
## 111      11                           SSctx,Amyg,AudC
## 122      12                               SSctx,VisIp
## 133      13                               SSctx,VisIp
## 144      14                           VisIp,Thal,Tect
## 155      15                       Cstr,VisC,AudC,Thal
## 166      16                       OFC,Cstr,Amyg,VisIp
## 177      17                            Cstr,Thal,Tect
## 188      18                                 Cstr,AudC
## 199      19                                 Thal,Tect
## 210      20                                 Rstr,Cstr
## 221      21                                 Cstr,Thal
## 232      22                           Cstr,VisIp,AudC
## 243      23                                  OFC,VisC
## 254      24                           Cstr,VisIp,AudC
## 265      25                          SSctx,VisIp,AudC
## 276      26                             OFC,Rstr,Cstr
## 287      27                           VisIp,AudC,Tect
## 298      28       OFC,Motor,Rstr,SSctx,Cstr,Thal,Tect

Investigating the uncertainty of each particle

We can plot the PSM within each region of attraction/neighborhood to understand the uncertainty of each particle.

# Posterior similarity matrix restricted to the draws assigned to each
# particle, with rows/columns grouped by that particle's clusters.

# Particle 1
psm_p1 <- mcclust::comp.psm(cls.draw[output_WASABI$draws.assign == 1, ])
hpsm_p1 <- superheat(
  psm_p1,
  pretty.order.rows = TRUE,
  pretty.order.cols = TRUE,
  heat.pal = c("white", "yellow", "red"),
  heat.pal.values = c(0, .5, 1),
  membership.rows = output_WASABI$particles[1, ],
  membership.cols = output_WASABI$particles[1, ],
  bottom.label.text.size = 4,
  left.label.text.size = 4,
  title = "PSM within particle 1's neighborhood"
)

# Particle 2
psm_p2 <- mcclust::comp.psm(cls.draw[output_WASABI$draws.assign == 2, ])
hpsm_p2 <- superheat(
  psm_p2,
  pretty.order.rows = TRUE,
  pretty.order.cols = TRUE,
  heat.pal = c("white", "yellow", "red"),
  heat.pal.values = c(0, .5, 1),
  membership.rows = output_WASABI$particles[2, ],
  membership.cols = output_WASABI$particles[2, ],
  bottom.label.text.size = 4,
  left.label.text.size = 4,
  title = "PSM within particle 2's neighborhood"
)

# Particle 3
psm_p3 <- mcclust::comp.psm(cls.draw[output_WASABI$draws.assign == 3, ])
hpsm_p3 <- superheat(
  psm_p3,
  pretty.order.rows = TRUE,
  pretty.order.cols = TRUE,
  heat.pal = c("white", "yellow", "red"),
  heat.pal.values = c(0, .5, 1),
  membership.rows = output_WASABI$particles[3, ],
  membership.cols = output_WASABI$particles[3, ],
  bottom.label.text.size = 4,
  left.label.text.size = 4,
  title = "PSM within particle 3's neighborhood"
)

# ggarrange(hpsm_p1, hpsm_p2, hpsm_p3, ncol=1, nrow=3, common.legend = TRUE, legend="bottom")

Investigating the meet

We can also find the meet of the particles.

First, we show line plots of neurons in each meet cluster, colored by the contribution to EVI.

# Meet of all particles and its cluster allocation
output_meet <- cls.meet(output_WASABI$particles)
z_meet <- output_meet$cls.m

# Meet clusters with more than 10 neurons
motifs_filter.m <- which(table(z_meet) > 10)
# Per-neuron contribution to the EVI, split by brain together with the meet
evi.m <- evi.wd.contribution(output_WASABI, z_meet)
meet_list <- lapply(seq_len(M), function(m) z_meet[mouse.index == m])
evi.m_list <- lapply(seq_len(M), function(m) evi.m[mouse.index == m])
projection_vic(
  data_barseq, mouse.list, meet_list, regions.name,
  motifs = motifs_filter.m, evi.m_list, ncol = 5
)

The posterior similarity matrix approximated by WASABI and collapsed to the meet clusters helps us to understand which meet clusters are grouped across particles.

# Posterior similarity matrix collapsed to the meet clusters, labelled 1..Km
psm.m <- psm.meet(z_meet, output_WASABI)
Km <- nrow(psm.m)
colnames(psm.m) <- 1:Km; rownames(psm.m) <- 1:Km

# Compare the meet with particle i: cross-tabulate particle clusters (rows)
# against meet clusters (columns)
i <- 1
part_cl <- output_WASABI$particles[i, ]
tb_meettop <- table(part_cl, z_meet)

# Long format, non-empty cells only, ordered by particle cluster (Var1) so
# that meet clusters belonging to the same particle cluster sit together.
# (A which.max-based lbs_top computed here in the original was overwritten
# below before being used, so it has been dropped.)
tmp <- reshape2::melt(as.matrix(as.data.frame.matrix(tb_meettop))) %>%
  arrange(Var1) %>%
  filter(value > 0)

# lbs_top: particle cluster of each retained cell; tmp: its meet cluster
lbs_top <- tmp %>% pull(Var1)
tmp <- tmp %>% pull(Var2)

superheat::superheat(psm.m[tmp,tmp],
                            heat.pal = c("white", "yellow", "red"),
                            heat.pal.values = c(0,.5,1),
                            heat.lim = c(0,1), # this is important!!
                            row.title = paste('Particle',i),
                            column.title = paste('Meet'),
                            membership.rows = as.numeric(lbs_top),
                            membership.cols = tmp,
                            bottom.label.text.angle = 90,
                            bottom.label.text.size = 3,
                            left.label.text.size = 3)

#ggsave(pm$plot,filename = "psm_meet_ac.png", device = "png", 
#       width = 7.5, height = 8,units = "in", scale = 1)

Let’s investigate the sizes of the meet clusters, colored by their total EVI contribution (the sum across neurons in the meet cluster). This helps us to understand which meet clusters are stable (groups of neurons with distinct projection patterns), and which may be more uncertain.

# Non-empty (particle cluster, meet cluster) pairs with their neuron counts,
# ordered by particle cluster.
tmp <- tb_meettop %>%
  as.data.frame.matrix() %>%
  as.matrix() %>%
  reshape2::melt() %>%
  arrange(Var1) %>%
  filter(value > 0)

# EVI contributions of the neurons assigned to meet cluster k
meet_ids <- unique(tmp$Var2)
evi_in_cluster <- function(k) evi.m[z_meet == k]

# Total EVI contribution per meet cluster, floored at zero
evi.m.group <- sapply(meet_ids, function(k) max(sum(evi_in_cluster(k)), 0))
# Distinct per-neuron EVI values found in each meet cluster
evi.m.unique <- sapply(meet_ids, function(k) unique(evi_in_cluster(k)))

# Factor levels follow the row order of `tmp`, which fixes the bar order
df <- data.frame(
  cluster = factor(tmp$Var2, levels = tmp$Var2),
  size = tmp$value,
  EVI = evi.m.group
)

ggplot(df) +
  geom_col(aes(x = cluster, y = size, fill = EVI)) +
  theme_bw() +
  scale_fill_gradientn(
    colours = colors,
    transform = "sqrt",
    labels = function(x) sprintf("%.4f", x)
  ) +
  theme(
    axis.text.x = element_text(size = 10, angle = 90, vjust = 0.5),
    axis.title.x = element_text(size = 12),
    axis.title.y = element_text(size = 12, angle = 90)
  ) +
  labs(fill = "EVICG")

To visualize the stable meet clusters, we focus on those with more than 10 neurons and a group-level EVI contribution below 0.002.

# Give every neuron the group-level EVI of its meet cluster, split by mouse.
# `evi.m.group` is ordered like `unique(tmp$Var2)`, so the lookup goes through
# `match()` on the cluster label rather than assuming the labels are 1..K.
evi_labels <- unique(tmp$Var2)
evi.m.group_list <- lapply(seq_len(M), function(m) {
  evi.m.group[match(z_meet[mouse.index == m], evi_labels)]
})

# Stable meet clusters: more than 10 neurons and group EVI below 0.002.
# Labels are taken from `tmp$Var2` (not the factor `df$cluster`) so that
# `motifs` has the same type as in the other calls to `projection_vic()`.
is_stable <- (evi.m.group < 0.002) & (df$size > 10)
motifs_filter.m <- tmp$Var2[is_stable]

projection_vic(data_barseq, mouse.list, meet_list, regions.name,
               motifs = motifs_filter.m, evi.m.group_list, ncol = 5,
               limts = c(0, max(evi.m.group))) +
  labs(color = "EVICG")

We can also look more carefully at some of the other meet clusters, for example those that form the noisy cluster 2 in particle 1.

# Meet clusters that overlap cluster 2 of particle 1
motifs_filter.m <- with(tmp, Var2[Var1 == 2])
evi_limits <- c(0, max(evi.m.group))
projection_vic(data_barseq, mouse.list, meet_list, regions.name,
               motifs = motifs_filter.m, evi.m.group_list,
               ncol = 5, limts = evi_limits) +
  labs(color = "EVICG")

Another example is the meet clusters that form cluster 4 or 5 in particle 1.

# Shared colour range so both figures are comparable
evi_limits <- c(0, max(evi.m.group))

# Meet clusters that overlap cluster 4 of particle 1
motifs_filter.m <- with(tmp, Var2[Var1 == 4])
projection_vic(data_barseq, mouse.list, meet_list, regions.name,
               motifs = motifs_filter.m, evi.m.group_list,
               ncol = 3, limts = evi_limits) +
  labs(color = "EVICG")

# Meet clusters that overlap cluster 5 of particle 1
motifs_filter.m <- with(tmp, Var2[Var1 == 5])
projection_vic(data_barseq, mouse.list, meet_list, regions.name,
               motifs = motifs_filter.m, evi.m.group_list,
               ncol = 3, limts = evi_limits) +
  labs(color = "EVICG")

Let’s also investigate the Thal and Tect clusters:

- Cluster 19 of particle 1 has moderate projection to Thal and Tect only.
- Cluster 6 of particle 1 also projects to Thal and Tect, but with higher strength to Thal.
- Cluster 17 of particle 1 also projects to Thal and Tect, and additionally projects weakly to Cstr.

# Shared colour range so the three figures are comparable
evi_limits <- c(0, max(evi.m.group))

# Meet clusters that overlap cluster 19 of particle 1
motifs_filter.m <- with(tmp, Var2[Var1 == 19])
projection_vic(data_barseq, mouse.list, meet_list, regions.name,
               motifs = motifs_filter.m, evi.m.group_list,
               ncol = 4, limts = evi_limits) +
  labs(color = "EVICG")

# Meet clusters that overlap cluster 6 of particle 1
motifs_filter.m <- with(tmp, Var2[Var1 == 6])
projection_vic(data_barseq, mouse.list, meet_list, regions.name,
               motifs = motifs_filter.m, evi.m.group_list,
               ncol = 3, limts = evi_limits) +
  labs(color = "EVICG")

# Meet clusters that overlap cluster 17 of particle 1
motifs_filter.m <- with(tmp, Var2[Var1 == 17])
projection_vic(data_barseq, mouse.list, meet_list, regions.name,
               motifs = motifs_filter.m, evi.m.group_list,
               ncol = 3, limts = evi_limits) +
  labs(color = "EVICG")

# Summarise every meet cluster by the regions it projects to.
# For each cluster the result has R rows (one per entry of `regions.name`):
#   cluster     - meet-cluster label
#   pp.regions  - comma-separated regions whose cluster-average value in
#                 `data_norm_cbind` is at least 0.02 (repeated on every row)
#   pp.weight   - fraction of all neurons assigned to the cluster
#   pp.strength - cluster-average value in `data_norm_cbind`, per region
meet_motif_names <- lapply(sort(unique(z_meet)), function(j) {
  in_cluster <- z_meet == j

  # Rebuild the matrix explicitly so a single-neuron cluster is not dropped to
  # a plain vector (and so no dimnames leak into the data frame row names).
  data.j <- matrix(data_norm_cbind[, in_cluster],
                   nrow = R, ncol = sum(in_cluster))
  data.j.average <- apply(data.j, 1, mean)

  data.frame(cluster = j,
             pp.regions = paste(regions.name[data.j.average >= 0.02],
                                collapse = ','),
             pp.weight = sum(in_cluster) / sum(C),
             pp.strength = data.j.average)
})

meet_motif_names <- do.call(rbind, meet_motif_names)

# Each cluster occupies R consecutive rows, so print the first row of each.
# The stride is R (the number of regions) rather than a hard-coded 11.
print(meet_motif_names[seq(1, nrow(meet_motif_names), by = R), c(1, 2)])
##     cluster                                pp.regions
## 1         1                                      Amyg
## 12        2      OFC,Motor,Rstr,SSctx,Cstr,Amyg,VisIp
## 23        3          Motor,SSctx,Cstr,Amyg,VisIp,AudC
## 34        4                                      AudC
## 45        5                           VisIp,VisC,AudC
## 56        6                      Cstr,Amyg,VisIp,AudC
## 67        7                                 VisC,AudC
## 78        8                            Cstr,VisC,AudC
## 89        9 Rstr,SSctx,Cstr,Amyg,VisIp,VisC,AudC,Thal
## 100      10            Cstr,VisIp,VisC,AudC,Thal,Tect
## 111      11                       Rstr,Cstr,Amyg,AudC
## 122      12                            Cstr,Amyg,AudC
## 133      13                                 Thal,Tect
## 144      14                                 Thal,Tect
## 155      15                                 Thal,Tect
## 166      16                                      Thal
## 177      17                                Amyg,VisIp
## 188      18                               SSctx,VisIp
## 199      19                               SSctx,VisIp
## 210      20                           VisIp,VisC,Tect
## 221      21                           SSctx,Amyg,AudC
## 232      22                                     VisIp
## 243      23                               SSctx,VisIp
## 254      24                            Rstr,Cstr,AudC
## 265      25                           VisIp,Thal,Tect
## 276      26                                 Cstr,AudC
## 287      27                           Rstr,SSctx,Cstr
## 298      28                       OFC,Cstr,Amyg,VisIp
## 309      29                                 Cstr,Tect
## 320      30                                 Rstr,Cstr
## 331      31                           Cstr,VisIp,AudC
## 342      32                       Cstr,VisC,AudC,Thal
## 353      33                                 Thal,Tect
## 364      34                                 Thal,Tect
## 375      35                            Cstr,Thal,Tect
## 386      36                            Cstr,Thal,Tect
## 397      37                            Cstr,Thal,Tect
## 408      38                            Cstr,Thal,Tect
## 419      39                            Cstr,Thal,Tect
## 430      40                                 Cstr,AudC
## 441      41                  Cstr,VisC,AudC,Thal,Tect
## 452      42                                 Cstr,AudC
## 463      43                                 Thal,Tect
## 474      44                                 Thal,Tect
## 485      45                                 Cstr,Thal
## 496      46                                 Cstr,Thal
## 507      47                           Cstr,VisIp,AudC
## 518      48                           Cstr,VisIp,AudC
## 529      49                                  OFC,VisC
## 540      50                               Motor,SSctx
## 551      51                          SSctx,VisIp,AudC
## 562      52                             OFC,Rstr,Cstr
## 573      53                           VisIp,AudC,Tect
## 584      54       OFC,Motor,Rstr,SSctx,Cstr,Thal,Tect